{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9e9b7651",
   "metadata": {},
   "source": [
    "# How to create a custom LLM class\n",
    "\n",
    "This notebook shows how to create a custom LLM wrapper, in case you want to use your own LLM or a different wrapper than the ones LangChain already supports.\n",
    "\n",
    "Wrapping your LLM with the standard `LLM` interface allows you to use your LLM in existing LangChain programs with minimal code modifications!\n",
    "\n",
    "As a bonus, your LLM will automatically become a LangChain `Runnable` and will get some optimizations, async support, the `astream_events` API, and more out of the box.\n",
    "\n",
    "## Implementation\n",
    "\n",
    "A custom LLM only needs to implement two things:\n",
    "\n",
    "| Method      | Description                                                                             |\n",
    "|-------------|-----------------------------------------------------------------------------------------|\n",
    "| `_call`     | Takes in a string and some optional stop words, and returns a string. Used by `invoke`. |\n",
    "| `_llm_type` | A property that returns a string, used for logging purposes only.                       |\n",
    "\n",
    "Optional implementations:\n",
    "\n",
    "| Method                | Description                                                                                                        |\n",
    "|-----------------------|--------------------------------------------------------------------------------------------------------------------|\n",
    "| `_identifying_params` | Used to help with identifying the model and printing the LLM; should return a dictionary. This is a **@property**. |\n",
    "| `_acall`              | Provides an async native implementation of `_call`, used by `ainvoke`.                                             |\n",
    "| `_stream`             | Method to stream the output token by token.                                                                        |\n",
    "| `_astream`            | Provides an async native implementation of `_stream`; in newer LangChain versions, defaults to `_stream`.          |\n",
    "\n",
    "Let's implement a simple custom LLM that just returns the first `n` characters of the input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2e9bb32f-6fd1-46ac-b32f-d175663710c0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from typing import Any, Dict, Iterator, List, Optional\n",
    "\n",
    "from langchain_core.callbacks.manager import CallbackManagerForLLMRun\n",
    "from langchain_core.language_models.llms import LLM\n",
    "from langchain_core.outputs import GenerationChunk\n",
    "\n",
    "\n",
    "class CustomLLM(LLM):\n",
    "    \"\"\"A custom LLM that echoes the first `n` characters of the input.\n",
    "\n",
    "    When contributing an implementation to LangChain, carefully document\n",
    "    the model, including its initialization parameters, include an\n",
    "    example of how to initialize the model, and include any relevant\n",
    "    links to the underlying model's documentation or API.\n",
    "\n",
    "    Example:\n",
    "\n",
    "        .. code-block:: python\n",
    "\n",
    "            model = CustomLLM(n=2)\n",
    "            result = model.invoke(\"hello\")\n",
    "            result = model.batch([\"hello\", \"world\"])\n",
    "    \"\"\"\n",
    "\n",
    "    n: int\n",
    "    \"\"\"The number of characters from the prompt to be echoed.\"\"\"\n",
    "\n",
    "    def _call(\n",
    "        self,\n",
    "        prompt: str,\n",
    "        stop: Optional[List[str]] = None,\n",
    "        run_manager: Optional[CallbackManagerForLLMRun] = None,\n",
    "        **kwargs: Any,\n",
    "    ) -> str:\n",
    "        \"\"\"Run the LLM on the given input.\n",
    "\n",
    "        Override this method to implement the LLM logic.\n",
    "\n",
    "        Args:\n",
    "            prompt: The prompt to generate from.\n",
    "            stop: Stop words to use when generating. Model output is cut off at the\n",
    "                first occurrence of any of the stop substrings.\n",
    "                If stop tokens are not supported, consider raising NotImplementedError.\n",
    "            run_manager: Callback manager for the run.\n",
    "            **kwargs: Arbitrary additional keyword arguments. These are usually passed\n",
    "                to the model provider API call.\n",
    "\n",
    "        Returns:\n",
    "            The model output as a string. Actual completions SHOULD NOT include the prompt.\n",
    "        \"\"\"\n",
    "        if stop is not None:\n",
    "            raise ValueError(\"stop kwargs are not permitted.\")\n",
    "        return prompt[: self.n]\n",
    "\n",
    "    def _stream(\n",
    "        self,\n",
    "        prompt: str,\n",
    "        stop: Optional[List[str]] = None,\n",
    "        run_manager: Optional[CallbackManagerForLLMRun] = None,\n",
    "        **kwargs: Any,\n",
    "    ) -> Iterator[GenerationChunk]:\n",
    "        \"\"\"Stream the LLM on the given prompt.\n",
    "\n",
    "        This method should be overridden by subclasses that support streaming.\n",
    "\n",
    "        If not implemented, calls to stream will by default fall back to\n",
    "        the non-streaming version of the model and return the output as\n",
    "        a single chunk.\n",
    "\n",
    "        Args:\n",
    "            prompt: The prompt to generate from.\n",
    "            stop: Stop words to use when generating. Model output is cut off at the\n",
    "                first occurrence of any of these substrings.\n",
    "            run_manager: Callback manager for the run.\n",
    "            **kwargs: Arbitrary additional keyword arguments. These are usually passed\n",
    "                to the model provider API call.\n",
    "\n",
    "        Returns:\n",
    "            An iterator of GenerationChunks.\n",
    "        \"\"\"\n",
    "        for char in prompt[: self.n]:\n",
    "            chunk = GenerationChunk(text=char)\n",
    "            if run_manager:\n",
    "                run_manager.on_llm_new_token(chunk.text, chunk=chunk)\n",
    "\n",
    "            yield chunk\n",
    "\n",
    "    @property\n",
    "    def _identifying_params(self) -> Dict[str, Any]:\n",
    "        \"\"\"Return a dictionary of identifying parameters.\"\"\"\n",
    "        return {\n",
    "            # The model name allows users to specify custom token counting\n",
    "            # rules in LLM monitoring applications (e.g., in LangSmith users\n",
    "            # can provide per token pricing for their model and monitor\n",
    "            # costs for the given LLM.)\n",
    "            \"model_name\": \"CustomChatModel\",\n",
    "        }\n",
    "\n",
    "    @property\n",
    "    def _llm_type(self) -> str:\n",
    "        \"\"\"Get the type of language model used by this LLM. Used for logging purposes only.\"\"\"\n",
    "        return \"custom\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f614fb7b-e476-4d81-821b-57a2ebebe21c",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Let's test it 🧪"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3feae15-4afc-49f4-8542-93867d4ea769",
   "metadata": {
    "tags": []
   },
   "source": [
    "This LLM implements LangChain's standard `Runnable` interface, which many LangChain abstractions support!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "dfff4a95-99b2-4dba-b80d-9c3855046ef1",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1mCustomLLM\u001b[0m\n",
      "Params: {'model_name': 'CustomChatModel'}\n"
     ]
    }
   ],
   "source": [
    "llm = CustomLLM(n=5)\n",
    "print(llm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8cd49199",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'This '"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "llm.invoke(\"This is a foobar thing\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "511b3cb1-9c6f-49b6-9002-a2ec490632b0",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'world'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "await llm.ainvoke(\"world\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d9d5bec2-d60a-4ebd-a97d-ac32c98ab02f",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['woof ', 'meow ']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "llm.batch([\"woof woof woof\", \"meow meow meow\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fe246b29-7a93-4bef-8861-389445598c25",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['woof ', 'meow ']"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "await llm.abatch([\"woof woof woof\", \"meow meow meow\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3a67c38f-b83b-4eb9-a231-441c55ee8c82",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "h|e|l|l|o|"
     ]
    }
   ],
   "source": [
    "async for token in llm.astream(\"hello\"):\n",
    "    print(token, end=\"|\", flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b62c282b-3a35-4529-aac4-2c2f0916790e",
   "metadata": {},
   "source": [
    "Let's confirm that it integrates nicely with other `LangChain` APIs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "d5578e74-7fa8-4673-afee-7a59d442aaff",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain_core.prompts import ChatPromptTemplate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "672ff664-8673-4832-9f4f-335253880141",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [(\"system\", \"you are a bot\"), (\"human\", \"{input}\")]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c400538a-9146-4c93-9fac-293d8f9ca6bf",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "llm = CustomLLM(n=7)\n",
    "chain = prompt | llm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "080964af-3e2d-4573-85cb-0d7cc58a6f42",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'event': 'on_chain_start', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'name': 'RunnableSequence', 'tags': [], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}}}\n",
      "{'event': 'on_prompt_start', 'name': 'ChatPromptTemplate', 'run_id': '7e996251-a926-4344-809e-c425a9846d21', 'tags': ['seq:step:1'], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}}}\n",
      "{'event': 'on_prompt_end', 'name': 'ChatPromptTemplate', 'run_id': '7e996251-a926-4344-809e-c425a9846d21', 'tags': ['seq:step:1'], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}, 'output': ChatPromptValue(messages=[SystemMessage(content='you are a bot'), HumanMessage(content='hello there!')])}}\n",
      "{'event': 'on_llm_start', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'input': {'prompts': ['System: you are a bot\\nHuman: hello there!']}}}\n",
      "{'event': 'on_llm_stream', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'chunk': 'S'}}\n",
      "{'event': 'on_chain_stream', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'tags': [], 'metadata': {}, 'name': 'RunnableSequence', 'data': {'chunk': 'S'}}\n",
      "{'event': 'on_llm_stream', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'chunk': 'y'}}\n",
      "{'event': 'on_chain_stream', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'tags': [], 'metadata': {}, 'name': 'RunnableSequence', 'data': {'chunk': 'y'}}\n"
     ]
    }
   ],
   "source": [
    "idx = 0\n",
    "async for event in chain.astream_events({\"input\": \"hello there!\"}, version=\"v1\"):\n",
    "    print(event)\n",
    "    idx += 1\n",
    "    if idx > 7:\n",
    "        # Truncate\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a85e848a-5316-4318-b770-3f8fd34f4231",
   "metadata": {},
   "source": [
    "## Contributing\n",
    "\n",
    "We appreciate all LLM integration contributions.\n",
    "\n",
    "Here's a checklist to help make sure your contribution gets added to LangChain:\n",
    "\n",
    "Documentation:\n",
    "\n",
    "* [ ] The model contains docstrings for all initialization arguments, as these will be surfaced in the [API Reference](https://api.python.langchain.com/en/stable/langchain_api_reference.html).\n",
    "* [ ] The class docstring for the model contains a link to the model API if the model is powered by a service.\n",
    "\n",
    "Tests:\n",
    "\n",
    "* [ ] Add unit or integration tests for the overridden methods. Verify that `invoke`, `ainvoke`, `batch`, and `stream` work if you've overridden the corresponding code.\n",
    "\n",
    "Streaming (if you're implementing it):\n",
    "\n",
    "* [ ] Make sure to invoke the `on_llm_new_token` callback\n",
    "* [ ] `on_llm_new_token` is invoked BEFORE yielding the chunk\n",
    "\n",
    "Stop Token Behavior:\n",
    "\n",
    "* [ ] Stop token should be respected\n",
    "* [ ] Stop token should be INCLUDED as part of the response\n",
    "\n",
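    "The stop-token behavior above can be sketched as a small helper that a custom `_call` could apply to the raw provider output. This is a minimal sketch, not a LangChain API: `truncate_at_stop` is a hypothetical name, and the helper assumes, as the checklist states, that the matched stop sequence stays in the returned text.\n",
    "\n",
    "```python\n",
    "from typing import List, Optional\n",
    "\n",
    "\n",
    "def truncate_at_stop(text: str, stop: Optional[List[str]] = None) -> str:\n",
    "    \"\"\"Hypothetical helper: cut `text` right after the earliest stop sequence.\n",
    "\n",
    "    The matched stop sequence is INCLUDED in the result, per the checklist.\n",
    "    \"\"\"\n",
    "    if not stop:\n",
    "        return text\n",
    "    # Keep everything up to the earliest point at which any stop sequence\n",
    "    # has been fully produced.\n",
    "    cut = len(text)\n",
    "    for token in stop:\n",
    "        idx = text.find(token)\n",
    "        if idx != -1:\n",
    "            cut = min(cut, idx + len(token))\n",
    "    return text[:cut]\n",
    "\n",
    "\n",
    "print(truncate_at_stop(\"Hello world. Goodbye.\", stop=[\".\"]))  # Hello world.\n",
    "```\n",
    "\n",
    "Inside `_call`, such a helper would run on the provider's completion before it is returned, e.g. `return truncate_at_stop(completion, stop)`, where `completion` is a hypothetical name for the text your API returned.\n",
    "\n",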
    "Secret API Keys:\n",
    "\n",
    "* [ ] If your model connects to an API, it will likely accept API keys as part of its initialization. Use Pydantic's `SecretStr` type for secrets so they don't get accidentally printed out when someone prints the model."
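    ,
    "\n",
    "A minimal sketch of this pattern, using a plain Pydantic model so that it runs without a provider; `ProviderSettings` and the key value are hypothetical. Because `CustomLLM` above declares its parameters as Pydantic fields (like `n: int`), an `api_key: SecretStr` field could be declared on it in the same way.\n",
    "\n",
    "```python\n",
    "from pydantic import BaseModel, SecretStr\n",
    "\n",
    "\n",
    "class ProviderSettings(BaseModel):\n",
    "    \"\"\"Hypothetical settings for an API-backed LLM.\"\"\"\n",
    "\n",
    "    api_key: SecretStr\n",
    "\n",
    "\n",
    "# Pydantic coerces the plain string into a SecretStr.\n",
    "settings = ProviderSettings(api_key=\"sk-not-a-real-key\")\n",
    "\n",
    "# Printing the model masks the secret instead of showing the raw key.\n",
    "print(settings)\n",
    "\n",
    "# The raw value is only returned by an explicit call.\n",
    "key = settings.api_key.get_secret_value()\n",
    "```"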
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
